![]() |
![]() |
![]() |
Apresentamos o CodeGemma, uma coleção de modelos de código aberto baseados nos modelos do Gemma do Google DeepMind (Gemma Team et al., 2024). O CodeGemma é uma família de modelos abertos leves e de última geração criados com base na mesma pesquisa e tecnologia usadas para criar os modelos do Gemini.
Seguindo os modelos pré-treinados do Gemma, os modelos CodeGemma são treinados em mais de 500 a 1.000 bilhões de tokens de código, usando as mesmas arquiteturas da família de modelos do Gemma. Como resultado, os modelos do CodeGemma alcançam a melhor performance de código em tarefas de conclusão e geração, mantendo habilidades de compreensão e raciocínio em grande escala.
O CodeGemma tem três variantes:
- Um modelo de código pré-treinado de 7B
- Um modelo de código ajustado por instruções do 7B
- Um modelo 2B, treinado especificamente para preenchimento de código e geração aberta.
Este guia orienta você a usar o modelo CodeGemma com o Flax para uma tarefa de conclusão de código.
Configuração
1. Configurar o acesso ao Kaggle para o CodeGemma
Para concluir este tutorial, primeiro siga as instruções de configuração em Configuração do Gemma, que mostram como fazer o seguinte:
- Acesse o CodeGemma em kaggle.com.
- Selecione um ambiente de execução do Colab com recursos suficientes (a GPU T4 tem memória insuficiente, use a TPU v2) para executar o modelo CodeGemma.
- Gere e configure um nome de usuário e uma chave de API do Kaggle.
Depois de concluir a configuração do Gemma, passe para a próxima seção, em que você vai definir variáveis de ambiente para o ambiente do Colab.
2. Defina as variáveis de ambiente
Defina as variáveis de ambiente para KAGGLE_USERNAME
e KAGGLE_KEY
. Quando receber a mensagem "Conceder acesso?", aceite para fornecer acesso ao segredo.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. Instalar a biblioteca gemma
No momento, a aceleração de hardware sem custo financeiro do Colab é insufficient para executar este notebook. Se você estiver usando o pagamento por uso ou o Colab Pro, clique em Editar > Configurações do notebook > Selecione GPU A100 > Salvar para ativar a aceleração de hardware.
Em seguida, instale a biblioteca gemma
do Google DeepMind em github.com/google-deepmind/gemma
. Se você receber um erro sobre o "resolvedor de dependências do pip", geralmente é possível ignorá-lo.
pip install -q git+https://github.com/google-deepmind/gemma.git
4. Importar bibliotecas
Este bloco de notas usa o Gemma (que usa o Flax para criar as camadas da rede neural) e o SentencePiece (para tokenização).
import os
from gemma.deprecated import params as params_lib
from gemma.deprecated import sampler as sampler_lib
from gemma.deprecated import transformer as transformer_lib
import sentencepiece as spm
Carregar o modelo do CodeGemma
Carregue o modelo do CodeGemma com kagglehub.model_download
, que recebe três argumentos:
handle
: o identificador do modelo do Kagglepath
: (string opcional) o caminho localforce_download
: (booleano opcional) força o novo download do modelo.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
Verifique o local dos pesos do modelo e do tokenizer e defina as variáveis de caminho. O diretório do tokenizer vai estar no diretório principal em que você fez o download do modelo, e os pesos do modelo vão estar em um subdiretório. Exemplo:
- O arquivo de tokenização
spm.model
vai estar em/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
- O checkpoint do modelo será em
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
Realizar amostragem/inferência
Carregue e formate o checkpoint do modelo do CodeGemma com o método gemma.params.load_and_format_params
:
params = params_lib.load_and_format_params(CKPT_PATH)
Carregue o tokenizer do CodeGemma, criado usando sentencepiece.SentencePieceProcessor
:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
Para carregar automaticamente a configuração correta do ponto de verificação do modelo do CodeGemma, use gemma.deprecated.transformer.TransformerConfig
. O argumento cache_size
é o número de etapas de tempo no cache Transformer
do CodeGemma. Em seguida, instancie o modelo do CodeGemma como model_2b
com gemma.deprecated.transformer.Transformer
(que herda de flax.linen.Module
).
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
Crie um sampler
com gemma.sampler.Sampler
. Ele usa o ponto de verificação do modelo CodeGemma e o tokenizer.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
Crie algumas variáveis para representar os tokens de preenchimento (fim) e crie algumas funções auxiliares para formatar o comando e a saída gerada.
Por exemplo, confira o seguinte código:
def function(string):
assert function('asdf') == 'fdsa'
Queremos preencher o function
para que a declaração mantenha True
. Nesse caso, o prefixo seria:
"def function(string):\n"
E o sufixo seria:
"assert function('asdf') == 'fdsa'"
Em seguida, formatamos isso em uma solicitação como PREFIXO-SUFIXO-MEIO (a seção do meio que precisa ser preenchida está sempre no final da solicitação):
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
Crie uma solicitação e realize a inferência. Especifique o texto do prefixo before
e o texto do sufixo after
e gere o comando formatado usando a função auxiliar format_completion prompt
.
É possível ajustar total_generation_steps
(o número de etapas realizadas ao gerar uma resposta. Este exemplo usa 100
para preservar a memória do host).
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
Saiba mais
- Saiba mais sobre a biblioteca
gemma
do Google DeepMind no GitHub (link em inglês), que contém docstrings de módulos usados neste tutorial, comogemma.params
,gemma.deprecated.transformer
egemma.sampler
. - As bibliotecas a seguir têm os próprios sites de documentação: core JAX, Flax e Orbax.
- Para a documentação do
sentencepiece
tokenizer/detokenizer, consulte o repositório do GitHubsentencepiece
do Google. - Para conferir a documentação de
kagglehub
, consulteREADME.md
no repositório do GitHubkagglehub
do Kaggle. - Saiba como usar modelos Gemma com a Vertex AI do Google Cloud.
- Se você estiver usando TPUs do Google Cloud (v3-8 e mais recentes), atualize também para o pacote
jax[tpu]
mais recente (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
), reinicie o ambiente de execução e verifique se as versõesjax
ejaxlib
correspondem (!pip list | grep jax
). Isso pode impedir oRuntimeError
que pode surgir devido à incompatibilidade de versõesjaxlib
ejax
. Para mais instruções de instalação do JAX, consulte os documentos do JAX.